import os
import random

import numpy as np

from allennlp.data import DataLoader

from dataset import *
from victim_model.text_classifier import TextClassifier
from victim_model.text_predictor import TextPredictor, TextPredictorEnsembler

from config import Config
from tools.color import Color
from tools.logger import Logger
from tools.saver import Saver
from tools.device_manager import DeviceManager
from tools.utils import write_json, read_json

# Module-level experiment configuration, created once at import time.
# Every function below reads its settings (dataset, encoder/token lists,
# batch size, attacker, adv_id, log paths, device, ...) from this object.
cf: Config = Config()


def build_data_reader():
    """Return the dataset-reader *class* that matches ``cf.dataset``.

    The class itself (not an instance) is returned; callers instantiate it
    with ``Reader(cf, token_type=...)`` for whichever token type they need.

    Returns:
        ``IMDBDatasetReader``, ``AGNewsDatasetReader`` or ``MRDatasetReader``.

    Raises:
        ValueError: if ``cf.dataset`` is not ``'imdb'``, ``'agnews'`` or ``'mr'``.
    """
    if cf.dataset == 'imdb':
        return IMDBDatasetReader
    if cf.dataset == 'agnews':
        return AGNewsDatasetReader
    if cf.dataset == 'mr':
        return MRDatasetReader
    # The message must list every dataset handled above (it previously omitted 'mr').
    raise ValueError(f'{cf.dataset} not implemented. Only support: imdb, agnews and mr.')


def build_predictor(Reader):
    """Load the victim model(s) described by ``cf`` and wrap them as a predictor.

    ``cf.encoder`` and ``cf.token`` are comma-separated lists paired
    position-wise. Each ``(encoder, token)`` pair loads the last-epoch
    ``TextClassifier`` checkpoint named ``<dataset>_<encoder>_<token>`` from
    the ``ckpt2`` checkpoint directory.

    Args:
        Reader: dataset-reader class (as returned by ``build_data_reader``);
            one instance is created per model with that model's token type.

    Returns:
        A single ``TextPredictor`` when exactly one model is configured,
        otherwise a ``TextPredictorEnsembler`` over all loaded predictors.

    Raises:
        ValueError: if ``cf.encoder`` and ``cf.token`` list a different number
            of entries. ``zip`` would otherwise silently drop the extras and
            evaluate a smaller ensemble than requested.
    """
    encoders = cf.encoder.split(',')
    tokens = cf.token.split(',')
    if len(encoders) != len(tokens):
        raise ValueError(
            f'cf.encoder ({cf.encoder!r}) and cf.token ({cf.token!r}) must list '
            f'the same number of comma-separated entries.')

    predictors = []
    for e, t in zip(encoders, tokens):
        reader = Reader(cf, token_type=t)
        saver = Saver(f'{cf.dataset}_{e}_{t}', d_ckpt='ckpt2')
        model = saver.load_last_epoch(TextClassifier, {'cf': cf, 'encoder_type': e, 'token_type': t})
        # Moved to the GPU unconditionally; presumably a torch module, so a
        # CUDA device must be available.
        model = model.cuda()
        predictors.append(TextPredictor(model, reader))

    if len(predictors) == 1:
        return predictors[0]
    return TextPredictorEnsembler(predictors)


def test_predictor(dataset, predictor):
    """Return one boolean flag per example in ``dataset``.

    Examples are sent to ``predictor.predict_batch_json`` in consecutive
    chunks of ``cf.batch_size``. An example's flag is ``True`` when its
    predicted label differs from its gold label (with a 1e-5 tolerance),
    which ``main`` averages into the reported success rate.
    """
    step = cf.batch_size
    flags = []
    for start in range(0, len(dataset), step):
        chunk_outputs = predictor.predict_batch_json(dataset[start:start + step])
        predicted = [out['pred'] for out in chunk_outputs]
        expected = [out['gold'] for out in chunk_outputs]
        flags.extend(abs(p - g) > 1e-5 for p, g in zip(predicted, expected))
    return flags


def get_adv_ids(adv_type):
    """Return the adversarial-example file names to evaluate.

    ``cf.adv_id`` selects the files:

    * ``'all'``      -- every file in ``adv/<adv_type>`` whose name contains
      both ``cf.dataset`` and ``cf.attacker`` (plain substring match);
    * ``'ensemble'`` -- as ``'all'``, restricted to names containing ``','``;
    * ``'single'``   -- as ``'all'``, restricted to names without ``','``;
    * any other value -- used verbatim as the only file name.

    Args:
        adv_type: sub-directory of ``adv/`` to scan.

    Returns:
        A list of file names (not full paths). Directory scans are sorted,
        because ``os.listdir`` order is arbitrary and would otherwise make
        evaluation and log order vary between runs.
    """
    # Extra per-mode name filter applied on top of the dataset/attacker match.
    mode_filters = {
        'all': lambda name: True,
        'ensemble': lambda name: ',' in name,
        'single': lambda name: ',' not in name,
    }
    if cf.adv_id not in mode_filters:
        return [cf.adv_id]

    keep = mode_filters[cf.adv_id]
    return sorted(
        name for name in os.listdir(f'adv/{adv_type}')
        if cf.dataset in name and cf.attacker in name and keep(name)
    )


def read_length(adv_id):
    """Return the ``'length'`` field of every record in ``adv/detail/<adv_id>``.

    Values are returned in file order; ``main`` indexes them in parallel
    with the per-example flags from ``test_predictor``.
    """
    records = read_json(f'adv/detail/{adv_id}')
    return [record['length'] for record in records]


def main(adv_type='semi_adv_examples'):
    """Evaluate stored adversarial examples against the configured victim model(s).

    For each file chosen by ``get_adv_ids``, the (optionally truncated)
    examples are re-classified by the predictor built from ``cf``. An example
    counts as a success when the prediction differs from the gold label. The
    success rate (SR) and the mean ``length`` (read from
    ``adv/detail/<adv_id>``) of the successful examples are printed and
    logged.

    Args:
        adv_type: sub-directory of ``adv/`` that holds the adversarial
            examples. Defaults to ``'semi_adv_examples'``.
    """
    logger = Logger(cf.p_log['transfer'], quiet=cf.quiet)
    logger.print(cf)

    Reader = build_data_reader()
    predictor = build_predictor(Reader)
    adv_ids = get_adv_ids(adv_type)

    for adv_id in adv_ids:
        logger.print(Color.green('reading data...'))
        reader = Reader(cf, token_type='word')
        victim_dataset = reader.read_json(f'adv/{adv_type}/{adv_id}')

        # Evaluate only the first cf.attack_dataset_size examples.
        victim_dataset = victim_dataset[:cf.attack_dataset_size]

        logger.print(Color.green('attack model...'))
        srs = test_predictor(victim_dataset, predictor)

        # NOTE(review): assumes the detail file lists examples in the same
        # order as the adversarial-example file, so lengths[i] describes the
        # i-th evaluated example -- confirm against the attack script.
        lengths = read_length(adv_id)
        success_lengths = [lengths[i] for i, success in enumerate(srs) if success]

        # np.average([]) warns "Mean of empty slice" and returns nan; produce
        # the same nan without the warning when nothing was evaluated/succeeded.
        sr_avg = np.average(srs) if srs else float('nan')
        length_avg = np.average(success_lengths) if success_lengths else float('nan')

        result = f'[{adv_id} {cf.model_id}] SR: {sr_avg} {length_avg} '
        logger.print(Color.red(result))
        logger.log(result)


if __name__ == '__main__':
    # Run the whole evaluation inside the device context selected by cf.device.
    device_context = DeviceManager(cf.device)
    with device_context:
        main()
